# -*- coding: utf-8 -*-
# !/usr/bin/env python
"""
-------------------------------------------------
   File Name：     model
   Description :   CenterNet model built from a ResNet-50 backbone, a decoder and a prediction head
   Author :       lth
   date：          2022/2/19
-------------------------------------------------
   Change Activity:
                   2022/2/19 6:13: create this script
-------------------------------------------------
"""
__author__ = 'lth'

import torch
from torch import nn

from common import resnet50, resnet50_Decoder, resnet50_Head


class CenterNet(nn.Module):
    """CenterNet model: ResNet-50 backbone -> decoder -> prediction head.

    Args:
        class_num: number of object classes; forwarded to the head as
            ``num_classes``.
        pretrained: forwarded positionally to ``resnet50``; presumably
            controls loading pretrained backbone weights -- confirm in
            ``common``.
    """

    def __init__(self, class_num, pretrained=True):
        super().__init__()
        # Attribute names (backbone / decoder / head) become state_dict key
        # prefixes, and registration order matters for checkpoints, so both
        # are kept as-is.
        self.backbone = resnet50(pretrained)
        # 2048 is the decoder's input channel count -- presumably the width
        # of the last ResNet-50 stage; TODO confirm against ``common``.
        self.decoder = resnet50_Decoder(2048)
        # channel=64 presumably matches the decoder's output width -- confirm.
        self.head = resnet50_Head(channel=64, num_classes=class_num)

    def forward(self, x):
        """Feed ``x`` through backbone, decoder and head; return the head's output."""
        return self.head(self.decoder(self.backbone(x)))

if __name__ == "__main__":
    # Smoke test: build a 2-class model and push one dummy 1x3x512x512 input
    # through it.
    model = CenterNet(2)
    # Inference-only check: eval() puts any normalization/dropout layers in
    # inference mode (no BatchNorm running-stat updates), and no_grad() avoids
    # building an autograd graph that is never used.
    model.eval()
    dummy_input = torch.ones([1, 3, 512, 512])
    with torch.no_grad():
        output = model(dummy_input)
    print(output)
